"""Script to create (and optionally install) a `.whl` archive for KerasHub.

By default this will also create a shim package for `keras-nlp` (the old
package name) that provides a backwards compatible namespace.

Usage:

1. Create `.whl` files in `dist/` and `keras_nlp/dist/`:

```
python3 pip_build.py
```

2. Also install the new packages immediately after:

```
python3 pip_build.py --install
```

3. Build and install only keras-hub (skip the keras-nlp shim package):

```
python3 pip_build.py --install --skip_keras_nlp
```
"""

import argparse
import datetime
import os
import pathlib
import re
import shutil
import subprocess
import sys

from keras_hub.src.version import __version__


def ignore_files(_, filenames):
    """Select the directory entries that `shutil.copytree` should skip.

    Used as the `ignore` callback of `shutil.copytree`: the first argument
    (the directory being visited) is unused, and every entry whose name
    contains `"_test"` is returned so it is left out of the build copy.
    """
    skipped = []
    for candidate in filenames:
        if "_test" in candidate:
            skipped.append(candidate)
    return skipped


# Directory / import name of the main library under the root path.
hub_package = "keras_hub"
# Directory / import name of the backwards-compatible keras-nlp shim package.
nlp_package = "keras_nlp"
# Staging directory (created under the root path) that sources are copied
# into before building; `build` removes it when it finishes.
build_directory = "tmp_build_dir"
# Name of the directory that build artifacts are written to and collected in.
dist_directory = "dist"
# Top-level project files copied next to the package sources when staging.
to_copy = ["pyproject.toml", "README.md"]


def update_nightly_version(build_path, version):
    """Rewrite the staged library version with the nightly package version.

    Args:
        build_path: `pathlib.Path` of the staging directory containing the
            copied `keras_hub/` package.
        version: The current library version string.

    Returns:
        The nightly version. The first `MAJOR.MINOR.PATCH` run in `version`,
        together with everything after it, is replaced by that run followed
        by `.dev<YYYYmmddHHMM>` (current local time). If `version` contains
        no such run, it is returned unchanged.

    Raises:
        ValueError: If no `__version__ = ...` line is found in the staged
            `keras_hub/src/version.py`. Without this check the wheel would
            silently keep the old version.
    """
    date = datetime.datetime.now()
    new_version = re.sub(
        r"([0-9]+\.[0-9]+\.[0-9]+).*",  # Match version without suffix.
        r"\1.dev" + date.strftime("%Y%m%d%H%M"),  # Add dev{date} suffix.
        version,
    )

    version_file = build_path / hub_package / "src" / "version.py"
    version_contents = version_file.read_text()
    # Anchor on line boundaries instead of requiring surrounding "\n"
    # characters, so an assignment on the very first or last line of the
    # file is also rewritten.
    version_contents, num_replaced = re.subn(
        r"^__version__ = .*$",
        f'__version__ = "{new_version}"',
        version_contents,
        flags=re.MULTILINE,
    )
    if not num_replaced:
        raise ValueError(f"Could not find `__version__` in {version_file}")
    version_file.write_text(version_contents)
    return new_version


def update_nightly_name(build_path, pkg_name):
    """Rename the staged distribution to its `-nightly` variant.

    Every literal `name = "<pkg_name>"` in `build_path / "pyproject.toml"`
    is rewritten to `name = "<pkg_name>-nightly"` in place.

    Args:
        build_path: `pathlib.Path` of the directory holding `pyproject.toml`.
        pkg_name: The current distribution name.

    Returns:
        The nightly distribution name.
    """
    nightly_name = f"{pkg_name}-nightly"
    old_entry = f'name = "{pkg_name}"'
    new_entry = f'name = "{nightly_name}"'

    pyproject = build_path / "pyproject.toml"
    updated = pyproject.read_text().replace(old_entry, new_entry)
    pyproject.write_text(updated)
    return nightly_name


def pin_keras_nlp_version(build_path, pkg_name, version):
    """Pin keras-nlp version and dependency to the keras-hub version.

    Rewrites `build_path / "pyproject.toml"` in place:

    * any line starting with `version =` becomes `version = "<version>"`;
    * the `dependencies = [...]` array (single- or multi-line) becomes
      `dependencies = ["<pkg_name>==<version>"]`.

    Both keys are matched only at the start of a line. Keys that merely end
    in them (e.g. `target-version`, `optional-dependencies`) are therefore
    left untouched.

    Args:
        build_path: `pathlib.Path` of the directory holding `pyproject.toml`.
        pkg_name: Distribution name the shim should depend on.
        version: Version to assign to the shim and to pin the dependency to.
    """
    pyproj_file = build_path / "pyproject.toml"
    pyproj_contents = pyproj_file.read_text()
    pyproj_contents = re.sub(
        r"^version\s*=.*$",
        f'version = "{version}"',
        pyproj_contents,
        flags=re.MULTILINE,
    )

    # Match the whole array, possibly spanning several lines. Quoted strings
    # are consumed as units so a `]` inside one (e.g. extras like
    # "pkg[jax]") does not end the match early.
    dependencies_array = (
        r"^dependencies\s*=\s*\["
        r"""(?:"[^"]*"|'[^']*'|[^\]"'])*"""
        r"\]"
    )
    pyproj_contents = re.sub(
        dependencies_array,
        f'dependencies = ["{pkg_name}=={version}"]',
        pyproj_contents,
        flags=re.MULTILINE,
    )
    pyproj_file.write_text(pyproj_contents)


def copy_source_to_build_directory(src, dst, package):
    """Stage a package and its top-level project files for building.

    The `src / package` tree is copied to `dst / package`, skipping entries
    selected by `ignore_files`. Each file named in `to_copy` is then copied
    from `src` into `dst`.
    """
    package_src, package_dst = src / package, dst / package
    shutil.copytree(package_src, package_dst, ignore=ignore_files)

    for extra_file in to_copy:
        source_file = src / extra_file
        shutil.copy(source_file, dst / extra_file)


def build_wheel(build_path, dist_path, name, version):
    """Build the project staged in `build_path` and collect its artifacts.

    Changes the working directory to `build_path` and runs
    `python -m build` with the interpreter running this script. Every file
    in `build_path / dist_directory` is copied into `dist_path`, and the
    function then checks that the expected `py3-none-any` wheel is there.

    Args:
        build_path: `pathlib.Path` of the staged project to build.
        dist_path: `pathlib.Path` of the directory to collect artifacts in.
            It is created (with parents) if missing.
        name: Distribution name. `-` is normalized to `_` for the wheel
            file name.
        version: Expected version in the wheel file name.

    Returns:
        Path of the wheel file inside `dist_path`.

    Raises:
        subprocess.CalledProcessError: If the build command fails. This is
            checked explicitly because `dist_path` may still contain a wheel
            with the same name from an earlier run, which would otherwise be
            reported as a successful build.
        ValueError: If the expected wheel file is not found in `dist_path`.
    """
    # The working directory is left at `build_path`; the caller is expected
    # to change back afterwards.
    os.chdir(build_path)
    subprocess.run([sys.executable, "-m", "build"], check=True)

    # Save the dist files generated by the build process.
    os.makedirs(dist_path, exist_ok=True)
    for fpath in (build_path / dist_directory).glob("*.*"):
        shutil.copy(fpath, dist_path)

    # Check for the expected .whl file path.
    name = name.replace("-", "_")
    whl_path = dist_path / f"{name}-{version}-py3-none-any.whl"
    if not os.path.exists(whl_path):
        raise ValueError(f"Could not find whl {whl_path}")
    print(f"Build successful. Wheel file available at {whl_path}")
    return whl_path


def build(root_path, is_nightly=False, keras_nlp=True):
    """Build the keras-hub wheel and, optionally, the keras-nlp shim wheel.

    Sources are staged under `root_path / build_directory`. That directory
    must not already exist, and it is always removed afterwards, even when
    the build fails.

    Args:
        root_path: `pathlib.Path` of the directory containing `keras_hub/`,
            the files in `to_copy`, and (when `keras_nlp=True`) the
            `keras_nlp/` shim project.
        is_nightly: Whether to build `-nightly` packages with a `.dev`
            version.
        keras_nlp: Whether to also build the keras-nlp shim package.

    Returns:
        List of paths of the built wheel files.

    Raises:
        ValueError: If the staging directory already exists.
    """
    build_root = root_path / build_directory
    # Check the same absolute path that is created and later deleted. The
    # bare relative name would be resolved against the current working
    # directory, which could let the cleanup below delete a pre-existing
    # directory.
    if os.path.exists(build_root):
        raise ValueError(f"Directory already exists: {build_directory}")
    # Create the staging directory before entering `try`, so the cleanup
    # only ever removes a directory this call created.
    os.mkdir(build_root)

    try:
        whls = []

        # Stage and build keras-hub.
        copy_source_to_build_directory(root_path, build_root, hub_package)
        version = __version__
        pkg_name = hub_package.replace("_", "-")
        if is_nightly:
            version = update_nightly_version(build_root, version)
            pkg_name = update_nightly_name(build_root, pkg_name)
            assert "dev" in version, "Version should contain dev"
            assert "nightly" in pkg_name, "Name should contain nightly"

        dist_path = root_path / dist_directory
        whls.append(build_wheel(build_root, dist_path, pkg_name, version))

        if keras_nlp:
            # Stage the shim in a subdirectory of the staging area; its
            # artifacts are collected in `<root>/keras_nlp/dist`.
            nlp_build_path = build_root / nlp_package
            nlp_dist_path = root_path / nlp_package / dist_directory
            copy_source_to_build_directory(
                root_path / nlp_package, nlp_build_path, nlp_package
            )

            # The shim depends on exactly the keras-hub package built above.
            pin_keras_nlp_version(nlp_build_path, pkg_name, version)
            nlp_pkg_name = nlp_package.replace("_", "-")
            if is_nightly:
                nlp_pkg_name = update_nightly_name(nlp_build_path, nlp_pkg_name)
                assert "dev" in version, "Version should contain dev"
                assert "nightly" in nlp_pkg_name, "Name should contain nightly"

            whls.append(
                build_wheel(
                    nlp_build_path, nlp_dist_path, nlp_pkg_name, version
                )
            )

        return whls
    finally:
        # `build_wheel` changes the working directory; restore it before
        # removing the staging directory.
        os.chdir(root_path)
        shutil.rmtree(build_root)


def install_whl(whls):
    """Force-reinstall each wheel in `whls`, without its dependencies.

    For every path, in order, runs
    `python -m pip install <path> --force-reinstall --no-dependencies`
    with the interpreter running this script.

    Args:
        whls: Iterable of wheel file paths.

    Raises:
        subprocess.CalledProcessError: If an installation fails.
    """
    for path in whls:
        print(f"Installing wheel file: {path}")
        # Pass an argument list (no shell), so paths containing spaces or
        # shell metacharacters work. `check=True` surfaces pip failures
        # instead of silently ignoring the exit status.
        subprocess.run(
            [
                sys.executable,
                "-m",
                "pip",
                "install",
                str(path),
                "--force-reinstall",
                "--no-dependencies",
            ],
            check=True,
        )


if __name__ == "__main__":
    # Command-line entry point: build the wheels, then optionally install
    # them.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--install",
        action="store_true",
        help="Whether to install the generated wheel file.",
    )
    parser.add_argument(
        "--nightly",
        action="store_true",
        help="Whether to generate nightly wheel file.",
    )
    parser.add_argument(
        "--skip_keras_nlp",
        action="store_true",
        help="Skip building the keras-nlp shim package.",
    )
    args = parser.parse_args()
    # Sources are located relative to this script, not the working directory.
    root_path = pathlib.Path(__file__).parent.resolve()
    whls = build(root_path, args.nightly, not args.skip_keras_nlp)
    if whls and args.install:
        install_whl(whls)
